Setup

library(tidyverse)
library(knitr)
library(ggpubr)

library(survival)

library(SemiCompRisks)
library(SemiCompRisksFreq)
library(splines2)

library(cmprsk)
library(survival)
library(rlist)

# for all the metrics 
library(pec)
library(pROC)
library(mccr)



# use Harrison's color coding 
# Color-blind friendly base colors from here:
# https://davidmathlogic.com/colorblind/#%23648FFF-%23785EF0-%23DC267F-%23FE6100-%23FFB000
cb_blue <- "#648FFF"; cb_red <- "#DC267F"; cb_purple <- "#785EF0"; cb_orange <- "#FE6100"; cb_grey <- "#CACACA"

# pink-and-green color-blind friendly paired palette: two hues (pink, green),
# each in a light and a dark shade; reordered so shades alternate
four_color_paired <- RColorBrewer::brewer.pal(n=4,name="PiYG")[c(3,4,2,1)]


# other palettes used in plots; the "_cb" variants are the color-blind
# friendly counterparts of the plain-named ones
two_color_cb <- c(cb_blue,cb_red)

three_color <- c("dodgerblue","firebrick3","purple3")
three_color_cb <- c(cb_blue,cb_red,cb_purple)

# the three cb colors with HSV value (V) reduced from 100 to 60 (darker shades)
three_color_cb_dark <- c("#3b5699", "#751342", "#453589")

four_color <- c("lightgray","firebrick3","purple3","dodgerblue")
four_color_cb <- c(cb_grey,cb_red,cb_purple,cb_blue)
four_color_forest <- c("dodgerblue","firebrick3","purple3","magenta")
four_color_forest_cb <- c(cb_blue,cb_red,cb_purple,cb_orange)
five_color <- c("lightgray","firebrick3","magenta","purple3","dodgerblue")
five_color_cb <- c(cb_grey,cb_red,cb_orange,cb_purple,cb_blue)

# greyscale fallbacks (e.g., for print)
four_color_grey <- c("grey80","grey60","grey40","grey20")
five_color_grey <- c("grey80","grey60","grey40","grey20","black")

# RColorBrewer::display.brewer.all(n=4,colorblindFriendly = TRUE)
# color-blind friendly categorical (qualitative) palettes
three_color_qual <- RColorBrewer::brewer.pal(n=3,name="Set2")
four_color_qual <- RColorBrewer::brewer.pal(n=4,name="Dark2")
five_color_qual <- RColorBrewer::brewer.pal(n=5,name="Dark2")

Functions

# define basis function for piecewise effect of non-terminal time on terminal time 
# Degree-0 B-spline = piecewise-constant indicator basis over the intervals
# delimited by the knots; Boundary.knots = c(0, Inf) lets the last piece
# extend over all times past the final interior knot.
# NOTE(review): relies on the global `h3tv_knots_temp`, which is not defined
# in this file as shown — confirm it is created upstream before prediction.
h3_b_fun <- function(x) splines2::bSpline(x,
                                          knots = h3tv_knots_temp,
                                          Boundary.knots = c(0,Inf),
                                          degree = 0,intercept = FALSE)


# Plot the predicted probabilities of occupying each of the four
# illness-death states over time for one validation patient, as a stacked
# area chart. Observed event times (if any) are overlaid as dashed step
# functions: red for the non-terminal event, black for the terminal event.
#
# Args:
#   tseq_pred: time grid at which predictions were evaluated
#   ValPred:   prediction object holding matrices p_neither, p_term_only,
#              p_both, p_nonterm_only (rows = times, cols = patients)
#   ValDat:    validation data with time1/event1 and time2/event2 per patient
#   i:         patient (column/row) index
#   colors:    palette for the four states (in factor-level order)
#   t_end:     right end of the time axis for the overlaid step functions
# Returns: a ggplot object.
plot_risk_profile <- function(tseq_pred,
                              ValPred, ValDat = valDat, 
                              i, 
                              colors = four_color_cb, 
                              t_end) {
  n_times <- length(tseq_pred)
  # factor-level order drives the stacking order of geom_area
  state_labels <- c("Neither", "Terminal Only", "Both", "Nonterminal Only")

  # long-format frame: one row per (time, state)
  plot_frame <- data.frame(
    Time = rep(tseq_pred, 4),
    Probability = c(ValPred$p_neither[, i],
                    ValPred$p_term_only[, i],
                    ValPred$p_both[, i],
                    ValPred$p_nonterm_only[, i]),
    Outcome = factor(rep(state_labels, each = n_times),
                     levels = state_labels)
  )

  g <- ggplot() +
    geom_area(plot_frame,
              mapping = aes(x = Time,
                            y = Probability,
                            colour = Outcome,
                            fill = Outcome)) +
    scale_color_manual(values = colors) +
    scale_fill_manual(values = colors) +
    theme_bw() +
    labs(title = paste0("Patient ", i, collapse = ""))

  # observed non-terminal event time, if it occurred
  if (ValDat[i, ]$event1 == 1) {
    obs_t1 <- ValDat[i, ]$time1
    g <- g +
      geom_step(data = data.frame(time = c(0, obs_t1, t_end),
                                  ind = c(0, 1, 1)),
                aes(x = time,
                    y = ind),
                linetype = "dashed",
                color = "red")
  }

  # observed terminal event time, if it occurred
  if (ValDat[i, ]$event2 == 1) {
    obs_t2 <- ValDat[i, ]$time2
    g <- g +
      geom_step(data = data.frame(time = c(0, obs_t2, t_end),
                                  ind = c(0, 1, 1)),
                aes(x = time,
                    y = ind),
                linetype = "dashed",
                color = "black")
  }

  return(g)
}


# Plot the predicted restricted mean time lost (RMTL) in each illness-death
# state for one validation patient, as a function of the truncation time,
# as a stacked area chart. Observed event times (if any) are overlaid as
# dashed vertical lines: red = non-terminal, black = terminal.
#
# Args:
#   tseq_pred: truncation-time grid; assumed to match ValPred$tseq
#   ValPred:   prediction object with tseq and matrices p_nonterm_only,
#              p_term_only, p_both, p_neither (rows = times, cols = patients)
#   ValDat:    validation data with time1/event1 and time2/event2 per patient
#   i:         patient (column/row) index
#   colors:    palette for the four RMTL components
# Returns: a ggplot object.
plot_RMTL_profile <- function(tseq_pred,
                              ValPred, ValDat = valDat, 
                              i, 
                              colors = four_color_cb) {
  
  n_t <- length(tseq_pred)
  RMTL <- data.frame(time = tseq_pred, 
                     RMTL_NT = c(0, rep(NA, n_t - 1)), 
                     RMTL_T = c(0, rep(NA, n_t - 1)), 
                     RMTL_NT_T = c(0, rep(NA, n_t - 1)), 
                     RMTL_HA = c(0, rep(NA, n_t - 1))
  )
  
  for (t in tseq_pred[-1]) {
    
    # row of RMTL corresponding to truncation time t
    # (fix: was hard-coded `t + 1`, which only worked for tseq_pred = 0:100)
    row_idx <- match(t, tseq_pred)
    
    tseq_temp <- ValPred$tseq[ValPred$tseq <= t]
    idx <- seq_len(length(tseq_temp))
    
    # trapezoid-style quadrature weights; NOTE(review): the c(0.5, ..., 0.5)
    # endpoints assume a unit-spaced grid — confirm if tseq_pred changes
    w <- c(0.5, diff(tseq_temp)[-1], 0.5)
    
    # integrate each state-occupancy probability up to t
    RMTL[row_idx, ]$RMTL_NT <- as.vector(w %*% ValPred$p_nonterm_only[idx, i])
    RMTL[row_idx, ]$RMTL_T <- as.vector(w %*% ValPred$p_term_only[idx, i])
    RMTL[row_idx, ]$RMTL_NT_T <- as.vector(w %*% ValPred$p_both[idx, i])
    RMTL[row_idx, ]$RMTL_HA <- as.vector(w %*% ValPred$p_neither[idx, i])
  }
  
  # wide -> long for stacked plotting
  RMTL_df <- RMTL %>% 
    gather(key = "name", 
           value = "RMTL", 
           -time)
  
  # plot RMTL over truncation time
  g <- ggplot(data = RMTL_df, 
              aes(x = time, 
                  y = RMTL,
                  color = name, 
                  fill = name)) +
    geom_area(position = "stack") +
    theme_bw() +
    theme(legend.position = "bottom") +
    scale_color_manual(values = colors) +
    scale_fill_manual(values = colors) + 
    labs(color = "", fill = "", 
         x = "Truncation time", 
         title = paste0("Patient ", i, collapse = ""))
  
  # observed non-terminal event time, if it occurred
  if (ValDat[i, ]$event1 == 1) {
    trueT1 <- ValDat[i, ]$time1
    
    g <- g +
      geom_vline(aes(xintercept = trueT1), 
                 linetype = "dashed", 
                 color = "red")
  }
  # observed terminal event time, if it occurred
  # (fix: previously tested the global `valDat` instead of the ValDat
  # argument, silently ignoring a user-supplied validation data set)
  if (ValDat[i, ]$event2 == 1) {
    trueT2 <- ValDat[i, ]$time2
    
    g <- g +
      geom_vline(aes(xintercept = trueT2), 
                 linetype = "dashed", 
                 color = "black")
  }
  
  return(g)
}

Load Data

# load data file 
# Simulated training / validation semi-competing-risks data; loading creates
# `trainDat` and `valDat` in the global environment.
load(file = "Data/SimData_Training.RData")
load(file = "Data/SimData_Validation.RData")

# True data-generating parameters (creates `paramList`).
load(file = "Data/TrueParameter_List.RData")

# extract true parameters
# covariate design used in the simulation
x_sim <- paramList$x1

# regression coefficients for the three transition hazards
# (1 = non-terminal, 2 = terminal, 3 = terminal after non-terminal)
beta1_true <- paramList$beta1.true
beta2_true <- paramList$beta2.true
beta3_true <- paramList$beta3.true

# Weibull shape parameters of the true baseline hazards
alpha1_true <- paramList$alpha1.true
alpha2_true <- paramList$alpha2.true
alpha3_true <- paramList$alpha3.true

# Weibull rate parameters of the true baseline hazards
kappa1_true <- paramList$kappa1.true
kappa2_true <- paramList$kappa2.true
kappa3_true <- paramList$kappa3.true

# piecewise (time-varying) effect of the non-terminal event time on hazard 3
h3tv_degree <- paramList$h3tv_degree 
h3tv_knots <- paramList$h3tv_knots
beta3tv_true <- paramList$beta3tv.true

# variance of the shared gamma frailty
theta_true <- paramList$theta

Shared frailty Illness-Death model with piecewise baseline hazard

The baseline hazard for each of the three transition intensities is modeled as piecewise constant with 10 pieces (10 baseline parameters per transition; see `nP0` below and the 10 knots per hazard in the output).

# define formula 
# Left of ~: (non-terminal time, indicator) | (terminal time, indicator).
# Right of ~: covariates for hazard 1 | hazard 2 | hazard 3.
form_temp_no_t1cat <- Formula::Formula(time1 + event1 | time2 + event2 ~ x1 + x2 + x3 | x1 + x2 + x3 | x1 + x2 + x3 )

# Fit the shared-frailty illness-death model with piecewise-constant ("pw")
# baseline hazards and a semi-Markov clock for the h3 transition.
SFID_no_t1cat <- 
  # SemiCompRisksFreq::
  
  FreqID_HReg2(Formula = form_temp_no_t1cat, 
               data = trainDat,
               hazard = "pw", 
               model = "semi-Markov",
               # number of baseline parameters 
               # (10 piecewise-constant levels per transition hazard)
               nP0 = c(10, 10, 10),
               frailty = TRUE, 
               optim_method = "BFGS",
               extra_starts = 0)

# model summary: frailty variance, hazard ratios, baseline pieces, knots
SFID_no_t1cat %>% summary()
## 
## Analysis of independent semi-competing risks data 
## Piecewise Constant baseline hazard specification
## semi-Markov specification for h3
## Confidence level: 95%
## 
## Variance of frailties:
##    theta       SE       LL       UL   lrtest lrpvalue 
##    1.021    0.086    0.866    1.205  279.055    0.000 
## SE(theta) computed from SE(log(theta)) via delta method.
## LL and UL computed on scale of log(theta) and exponentiated,
## e.g., 95% UL = exp(log(theta) + 1.96 * SE(log(theta))).
## Likelihood ratio test of theta=0 vs. theta>0 using mixture of chi-squareds null.
## 
## Hazard ratios:
##    exp(beta1)    LL    UL exp(beta2)    LL    UL exp(beta3)    LL    UL
## x1      1.376 1.265 1.496      1.601 1.434 1.788      1.956 1.771 2.161
## x2      2.020 1.841 2.217      2.448 2.169 2.762      2.501 2.239 2.794
## x3      0.758 0.697 0.825      0.639 0.573 0.714      0.424 0.381 0.471
## 
## Baseline hazard function components:
##                            h1-PM    SE     LL     UL  h2-PM    SE     LL     UL
## Piecewise Constant: phi1  -3.057 0.098 -3.248 -2.865 -3.967 0.155 -4.270 -3.664
## Piecewise Constant: phi2  -3.073 0.096 -3.261 -2.885 -4.037 0.150 -4.331 -3.743
## Piecewise Constant: phi3  -2.970 0.097 -3.159 -2.780 -4.120 0.147 -4.408 -3.832
## Piecewise Constant: phi4  -3.107 0.099 -3.301 -2.913 -3.878 0.150 -4.171 -3.584
## Piecewise Constant: phi5  -2.901 0.103 -3.104 -2.699 -3.861 0.150 -4.154 -3.568
## Piecewise Constant: phi6  -3.012 0.108 -3.224 -2.800 -4.086 0.151 -4.383 -3.789
## Piecewise Constant: phi7  -2.780 0.116 -3.007 -2.553 -3.809 0.158 -4.118 -3.500
## Piecewise Constant: phi8  -2.948 0.126 -3.195 -2.700 -3.771 0.167 -4.098 -3.444
## Piecewise Constant: phi9  -2.837 0.144 -3.120 -2.554 -3.614 0.177 -3.961 -3.267
## Piecewise Constant: phi10 -2.782 0.172 -3.119 -2.445 -3.789 0.201 -4.183 -3.394
##                            h3-PM    SE     LL     UL
## Piecewise Constant: phi1  -3.306 0.121 -3.544 -3.068
## Piecewise Constant: phi2  -3.074 0.117 -3.303 -2.845
## Piecewise Constant: phi3  -3.253 0.115 -3.478 -3.029
## Piecewise Constant: phi4  -3.325 0.114 -3.548 -3.103
## Piecewise Constant: phi5  -3.353 0.115 -3.578 -3.128
## Piecewise Constant: phi6  -3.426 0.116 -3.653 -3.199
## Piecewise Constant: phi7  -3.241 0.117 -3.471 -3.012
## Piecewise Constant: phi8  -3.261 0.120 -3.497 -3.025
## Piecewise Constant: phi9  -3.339 0.124 -3.583 -3.096
## Piecewise Constant: phi10 -3.205 0.130 -3.460 -2.949
## 
## Knots:
##            h1     h2     h3
## knot1   0.000  0.000  0.000
## knot2   1.085  0.829  0.992
## knot3   2.628  2.082  2.095
## knot4   4.488  4.045  3.921
## knot5   7.458  6.132  6.700
## knot6  10.819  9.016 10.664
## knot7  16.256 14.613 16.780
## knot8  22.642 21.199 23.889
## knot9  34.828 31.603 34.493
## knot10 55.930 48.344 52.411

Visualize fitted shared frailty Illness-Death model

We are restricting ourselves to a time window 0 to 100. Since no censoring was modeled, the terminal event time for some patients might be very high.

# estimate for a certain time window 
# Prediction grid: integer times 0..100. NOTE(review): several later chunks
# index prediction matrices by `t + 1`, which assumes exactly this
# unit-spaced grid starting at 0 — keep them in sync if this changes.
tseq_pred <- seq(0, 100)

SFID_no_t1cat_pred <- predict(SFID_no_t1cat, 
                              tseq = tseq_pred)

# fitted hazard, cumulative hazard and survival curves per transition
plot(SFID_no_t1cat_pred, plot.est = "Haz")

plot(SFID_no_t1cat_pred, plot.est = "CumHaz")

plot(SFID_no_t1cat_pred, plot.est = "Surv")

Forest plot for model coefficients

The covariate coefficients for the marginal and conditional hazard functions are accurately estimated.

# Estimated coefficients (long format) labeled by parameter name and the
# transition hazard they belong to.
SFID_coef <- summary(SFID_no_t1cat)$coef_long %>% 
  as.data.frame() %>% 
  mutate(name = c("beta1.1", "beta1.2", "beta1.3", 
                  "beta2.1", "beta2.2", "beta2.3", 
                  "beta3.1", "beta3.2", "beta3.3"), 
         haz = c(rep("non-terminal", 3), 
                 rep("terminal", 3), 
                 rep("terminal | non-terminal", 3)))

# True simulation coefficients in the same layout, for overlay.
SFID_coef_true <- data.frame(name = c("beta1.1", "beta1.2", "beta1.3", 
                                      "beta2.1", "beta2.2", "beta2.3", 
                                      "beta3.1", "beta3.2", "beta3.3"), 
                             true_coef = c(beta1_true, beta2_true, beta3_true), 
                             haz = c(rep("non-terminal", 3), 
                                     rep("terminal", 3), 
                                     rep("terminal | non-terminal", 3)))

# Forest plot: point estimates with 95% Wald intervals (beta +/- 1.96 SE),
# true coefficients marked with stars (shape 8).
ggplot(data = SFID_coef, 
       aes(x = name, 
           y = beta, 
           color = haz %>% as.factor())) +
  geom_point() +
  geom_point(data = SFID_coef_true, 
             aes(x = name, 
                 y = true_coef, 
                 color = haz %>% as.factor(), 
                 shape = "True coef."),
             size = 4) +
  geom_errorbar(aes(ymin = beta - 1.96 * SE, 
                    ymax = beta + 1.96 * SE)) +
  coord_flip() +
  theme_bw() +
  theme(legend.position = "bottom") +
  labs(y = "Coefficients", 
       x = "", 
       color = "", 
       shape = "") +
  scale_shape_manual(values = c("True coef." = 8))

Baseline hazard estimation

# comparison of estimated baseline hazard and true baseline hazard 
#
# The data were simulated with Weibull baseline hazards parameterized so that
# S0(t) = exp(-kappa * t^alpha), i.e. shape = alpha, scale = kappa^(-1/alpha).
# Under this parameterization the TRUE baseline hazard and cumulative hazard
# are:
#   h0(t) = alpha * kappa * t^(alpha - 1)
#   H0(t) = kappa * t^alpha
#
# FIX: the previous version used dweibull() (the density) as the hazard and
# pweibull() (the CDF) as the cumulative hazard. density = hazard * survival,
# so the two only agree while survival is close to 1; at later times the
# density understates the true hazard, biasing the comparison with the
# estimated piecewise-constant levels exp(phi).

# Build the true-baseline-hazard frame for one transition.
#   alpha, kappa : true Weibull parameters of this transition
#   label        : transition index (1 = non-terminal, 2 = terminal,
#                  3 = terminal given non-terminal)
#   tseq         : evaluation time grid
make_true_bs_hazard <- function(alpha, kappa, label, tseq) {
  data.frame(param = rep(paste0("alpha", label, " = ", alpha, 
                                ", kappa", label, " = ", round(kappa, 2)), 
                         length(tseq)), 
             type = paste0("h", label), 
             time = tseq, 
             # true baseline hazard h0(t) = alpha * kappa * t^(alpha - 1)
             bs_hazard = alpha * kappa * tseq^(alpha - 1),
             # true cumulative baseline hazard H0(t) = kappa * t^alpha
             bs_cumhazard = kappa * tseq^alpha)
}

# baseline hazard for non-terminal event 
bs_hazard_1 <- make_true_bs_hazard(alpha1_true, kappa1_true, 1, tseq_pred)

# baseline hazard for terminal event 
bs_hazard_2 <- make_true_bs_hazard(alpha2_true, kappa2_true, 2, tseq_pred)

# conditional baseline hazard 
bs_hazard_3 <- make_true_bs_hazard(alpha3_true, kappa3_true, 3, tseq_pred)

true_bs_hazards <- rbind(bs_hazard_1, 
                         bs_hazard_2,
                         bs_hazard_3)


# Estimated piecewise-constant baseline log-hazard levels (one row per piece
# per transition), with rownames preserved as a `coef` column.
pw_coef_est <- summary(SFID_no_t1cat)$h0_long %>% 
  as.data.frame() %>% 
  mutate(coef = rownames(summary(SFID_no_t1cat)$h0_long))

# plot baseline hazard 
# Knot locations per transition, long format, sorted so rows align with
# pw_coef_est after the cbind below (both ordered by type then time/piece).
pw_coef_knots <- summary(SFID_no_t1cat)$knots_mat %>% 
  as.data.frame() %>% 
  gather(key = "type", 
         value = "time") %>% 
  arrange(type, time)

# Pair each piece's start knot (time) with its end knot (time_b; the last
# piece is closed at 100) and with the previous piece's level (beta_b) for
# drawing the vertical jumps.
pw_baseline <- cbind(pw_coef_est, 
                     pw_coef_knots) %>% 
  group_by(type) %>% 
  mutate(time_b = c(lead(time) %>% na.omit, 100), 
         beta_b = lag(beta))

# Step-function rendering of the estimated baseline hazard exp(beta):
# points at piece starts, horizontal segments over each piece, dotted
# vertical segments at the jumps, with the true baseline hazard overlaid.
ggplot(data = pw_baseline) +
  geom_point(mapping = aes(x = time, 
                           y = beta %>% exp(), 
                           color = type)) +
  geom_segment(mapping = aes(x = time, xend = time_b,
                             y = beta %>% exp(), 
                             yend = beta %>% exp(), 
                             color = type)) +
  geom_segment(mapping = aes(x = time, xend = time,
                           y = beta %>% exp(), 
                           yend = beta_b %>% exp(),
                           color = type),
             linetype = "dotted") +
  geom_line(data = true_bs_hazards,
            aes(x = time, 
                y = bs_hazard, 
                color = type)) +
  theme_bw() +
  theme(legend.position = "bottom") +
  labs(y = "baseline hazard function", 
       color = "")
## Warning: Removed 3 rows containing missing values or values outside the scale range
## (`geom_segment()`).
## (expected: beta_b is NA for the first piece of each of the 3 transitions)

Risk profiles for validation data set

# n = 1000 for validation data set 
# covariate matrix shared by all three transition hazards
valX <- valDat %>% 
  dplyr::select(x1, x2, x3) %>% 
  as.matrix()

# Patient-level predicted state-occupancy probabilities on tseq_pred.
# NOTE(review): pred_risk_ID is an unexported (:::) function of
# SemiCompRisksFreq, so this call may break across package versions.
ValPred <- 
  SemiCompRisksFreq:::pred_risk_ID(tseq = tseq_pred,
                                   para = SFID_no_t1cat$estimate,
                                   
                                   # same covariates enter h1, h2 and h3
                                   x1new = valX, 
                                   x2new = valX, 
                                   x3new = valX,
                                   
                                   frailty = SFID_no_t1cat$frailty, 
                                   model = SFID_no_t1cat$model,
                                   nP0 = SFID_no_t1cat$nP0, 
                                   nP = SFID_no_t1cat$nP,
                                   # no time-varying h3 effect in this model
                                   p3tv = 0, 
                                   h3tv_basis_func = h3_b_fun, 
                                   hazard = SFID_no_t1cat$hazard,
                                   knots_list = SFID_no_t1cat$knots_list,
                                   n_quad = SFID_no_t1cat$n_quad, 
                                   quad_method = SFID_no_t1cat$quad_method,
                                   Finv = NULL, 
                                   alpha = 0.05)
# right end of the plotting window for the overlaid event-time steps
trunc_time <- 100
g_list <- list()
j <- 1 

# risk profiles for a handful of example validation patients
for (i in c(1, 10, 300, 500, 700, 1000)) {
  
  g <- plot_risk_profile(tseq_pred = tseq_pred, 
                         ValPred = ValPred, 
                         i = i, 
                         t_end = trunc_time)
  
  g_list[[j]] <- g 
  j <- j + 1
  
}

# 2 x 3 panel of patient risk profiles with one shared legend
ggarrange(plotlist = g_list, 
          nrow = 2, ncol = 3, 
          common.legend = TRUE,
          legend = "bottom")

Restricted Mean Time Lost (RMTL)

Truncation time is 100 [years].

Restricted mean time lost over time in each state

According to Mozumder et al. (2021).

Compute the RMTL over time for each patient and average over prediction to obtain mean predicted RMTL for each state over time.

The population averaged version follows below.

On a patient specific level,

# RMTL profiles for the same example validation patients as above,
# arranged in a 2 x 3 panel with one shared legend.
patient_ids <- c(1, 10, 300, 500, 700, 1000)

g_list <- lapply(patient_ids, function(pt) {
  plot_RMTL_profile(tseq_pred = tseq_pred,
                    ValPred = ValPred,
                    i = pt)
})

ggarrange(plotlist = g_list,
          nrow = 2, ncol = 3,
          common.legend = TRUE,
          legend = "bottom")

Scatterplots for predicted and observed RMTL due to non-terminal event and / or death

Scatterplots of observed and predicted time-to-event for all three states. For non-observed state, have marginal distribution.

(1) Time to death

(2) Time to non-terminal event

Population-averaged Cumulative Incidence Curve based on Illness-Death model and alternative estimation methods

Competing risk framework

Naively, mortality and the non-terminal endpoint could be analyzed as competing endpoints. The Aalen-Johansen estimator estimates nonparametric Cumulative Incidence Curves for mortality and the non-terminal endpoint. Alternatively, we could estimate a competing risk Cox regression.
The survival probability for the composite endpoint, describing the first occurrence of either death or the non-terminal event, is estimated as a Kaplan Meier survival curve.

The Aalen-Johansen CIC estimate describes (1) the cumulative probability of dying (D) before experiencing the non-terminal event and (2) the cumulative probability of experiencing the non-terminal (NT) event before death.

For some \(t \in [0, \tau]\), \[\begin{align} CIC_D(t) & = P(T_D \leq t; T_{NT} > T_D) \\ & = P(T_D \leq t; \delta_{NT} = 0) \\ \\ CIC_{NT}(t) & = P(T_{NT} \leq t; T_{D} > T_{NT}) \end{align}\]

Thinking about the CICs, we can further express the CIC for the non-terminal endpoint in terms of mortality happening subsequentially or not, \[\begin{align} CIC_{NT}(t) & = P(T_{NT} \leq t; T_{D} > T_{NT}) \\ & = P(T_{NT} \leq t; T_D \leq t; T_{D} > T_{NT}) + P(T_{NT} \leq t; T_D > t; T_{D} > T_{NT}) \end{align}\]

The shared-frailty illness death model computes the marginal state probabilities, cumulative up to a certain time point. Marginal refers to the marginalization over the frailty distribution. Patients can transition out of state probabilities, as they move from Healthy to Ill and / or Dead over time. For some \(t \in [0, \tau]\), \[\begin{align} P(\text{Terminal only before/at } t) & = P(T_D \leq t; T_{NT} > T_D; \delta_{NT} = 0) \\ & = P(T_D \leq t; T_{NT} > T_D) \\ & = CIC_D(t) \\ \\ P(\text{Non-terminal only before/at} t) & = P(T_{NT} \leq t; T_{D} > T_{NT}; \delta_D = 0) \\ P(\text{Both events happening at } t) & = P(T_{NT} \leq t; T_D \leq t; T_{D} > T_{NT}) \\ \end{align}\]

The CIC for mortality is equivalent to the mortality only probability, estimated by the illness death model, as any death following after the occurrence of the non-terminal event is omitted in the competing risk framework.

The CIC for the non-terminal event is the sum of the non-terminal only and terminal / non-terminal cumulative state probability. \[\begin{align} CIC_{NT}(t) & = P(\text{Non-terminal only before/at } t) + P(\text{Both events happening before/at } t) \end{align}\]

Alternatively, the CIC can be estimated based on parametric competing risk models, accounting for the same covariates as the Illness-Death model, based on (1) Fine and Gray subdistribution hazard function and (2) cause-specific Cox proportional hazard functions.

## 
##    0    1    2 
##  280 1219  501
## 
##    0    1    2 
##  280  501 1219
##    
##       0   1
##   0 280 501
##   1 325 894

Univariate perspective on death

The competing risk framework omits any death happening post non-terminal event. If we only consider death, disregarding the non-terminal event, we might end up with observing more deaths.

Consider the marginal probability of death, which contains death without and with having experienced the non-terminal event. It is the sum of terminal event only and the conditional probability of experiencing the terminal event post non-terminal event at some time point prior to \(t\). We consider the marginal probabilities, marginalized over the frailty distribution, by inserting (according to the marginalized distribution derived above somewhere), \(\gamma_i = 1\). The derivation below shows the marginal probability of mortality equals the sum of the probability of terminal only and both events happening. \[\begin{align} P(T_D \leq t | x_i) & = P(T_D \leq t, \delta_{NT} = 0 | x_i) + P(T_D \leq t, \delta_{NT} = 1 | x_i) \\ & = CIC_D(t | x_i) + \int_{0}^t \int_{0}^u \lambda_1(s | x_i) \cdot \exp(-\Lambda_1(s | x_i) - \Lambda_2(s | x_i)) \cdot \lambda_3(u | x_i) \exp(-\Lambda_3(u | x_i, s)) \partial s \partial u \\ & = CIC_D(t | x_i) + P(T_D \leq t, T_{NT} \leq t | x_i) \end{align}\]

We estimate the marginal mortality probability based on a Cox proportional hazard model and nonparametrically via Kaplan-Meier.

##    
##       0   1
##   0 280 501
##   1 325 894
## .
##    0    1 
## 1499  501

Classification rates

Correct classification rate

Consider all 4 states equal and compute the correct classification rate at each time point that is predicted.

The overall correct classification rate is determined as the sum over all 4 states. The states can be acknowledged to have different clinical relevance by introducing weights.

The correct classification rate depends on prevalence, as it is the prevalence weighted average of state-specific classification rate (Li 2008). We compute the total correct classification rate marginally, by binarizing.

Entropy (log loss) for multiclass predictions

Entropy is \(E = -\sum_i \sum_{k = 1}^4 \delta_{k, i} \log(P_k(x_i))\). As time progressed, the entropy increases, showing a decreasing prediction accuracy.

ROC surface

Considering the univariate specificity and sensitivity for each state, for a three-dimensional tuple of thresholds, creates a three-dimensional ROC surface. The volume under the manifold (HUM) or volume under the ROC surface (VUS) has been proposed as a joint assessment of prediction accuracy (see Lee).

It describes the probability of correct classification, according to some decision rule. Decision rules for multiple states must be correctly generalized from standard binary decision rules. Multiple different definitions are available (Li 2008).

Depending on the decision rule, the VUS / HUM has a certain interpretation. All three decision rules discussed in Li 2008 provide equivalent HUM interpretation. It describes how likely class-probabilities follow a certain sequence aligning with a (potentially arbitrary) sequence of classes. The probability to obtain the correct ordering / ranking based on a coin toss is \(1 / M!\).

Inference on HUM can be performed based on U-statistic theory.

For now, the data is simulated without censoring. The outcome is known for every patient.

##   time AUC_nonterm  AUC_term  AUC_both AUC_neither      HUM
## 1   10   0.5727653 0.7098790 0.7849094   0.6929807 2.760534
## 2   50   0.7351586 0.6840825 0.6660859   0.6745007 2.759828
## 3   90   0.7551184 0.6568143 0.6310468   0.6793980 2.722378

Multiclass Matthew correlation coefficient

The Matthew correlation coefficient (MCC) measures the correlation between the predicted and observed probability for a binary classification. It quantifies specificity and sensitivity, as well as precision (correct identification among ill patients) and the true negative probability among the healthy patients. It has been argued to provide a more relevant impression of the quality of a decision rule as it takes its predictive accuracy into account.

##   time       MCC
## 1   10 0.1819033
## 2   50 0.1998995
## 3   90 0.2054696

Expected calibration error

The expected calibration error is computed for a sequence of landmark time points, as the empirical mean for all patients.

We currently have no censoring. A potential expansion could be an IPW weighted estimator.

Healthy and Alive

The probability for neither event happening equals the survival probability for the composite endpoint, consisting of the non-terminal outcome and mortality, whichever came earlier. \[\begin{align} P(\text{Neither event is happening before/at } t) & = P(T_D > t; T_{NT} > t) \\ & = P(\min(T_D, T_{NT}) > t) \\ & = S(\tilde{T} > t) \end{align}\] where \(\tilde{T} = \min(T_D, T_{NT})\) is the composite endpoint.

To establish the prediction accuracy, we compare the predicted probability for being healthy and alive with the true probability according to the data simulation. We know the true hazard functions, under a semi-Markov assumption \[\begin{align} \lambda_1(t | x_i) & = \lambda_{01}(t) \cdot \gamma_i \cdot \exp(\beta_1^Tx_i) \\ \lambda_2(t | x_i) & = \lambda_{02}(t) \cdot \gamma_i \cdot \exp(\beta_2^Tx_i) \\ \lambda_3(\delta t | x_i) & = \lambda_{03}(\delta t) \cdot \gamma_i \cdot \exp(\beta_3^Tx_i) \end{align}\] Marginalizing over the frailty distribution, \(\gamma_i \sim \Gamma(\theta, \theta)\), where \(\theta = 1\), \[\begin{align} \lambda_{m, j} (t | x_i) & = \lambda_{01}(t) \cdot \gamma_i \cdot \exp(\beta_1^Tx_i) \cdot E(\gamma_i | \theta = 1) \\ & = \lambda_{01}(t) \cdot \gamma_i \cdot \exp(\beta_1^Tx_i) \cdot 1 \end{align}\] And so the marginal quantities follow with \(\gamma_i = 1\).

And in turn the true cumulative hazard functions, \[\begin{align} \Lambda_1(t | x_i) & = \gamma_i \int_0^{t} \lambda_1(u | x_i) \partial u \\ \Lambda_2(t | x_i) & = \gamma_i \int_0^{t} \lambda_2(u | x_i) \partial u \\ \Lambda_3(t | x_i) & = \gamma_i \int_0^{t} \lambda_3(\delta u | x_i) \partial \delta u \end{align}\]

The cumulative probability for healthy and alive can be expressed as \[\begin{align} P(\text{Healthy and Alive before / at } t) &= P(T_D > t, T_{NT} > t) \\ & = \exp\left(-\gamma_i \cdot \left(\Lambda_1(t | x_i) + \Lambda_2(t | x_i) \right) \right) \\ & = \exp\left( - \gamma_i \cdot \exp(\beta_1^Tx_i) \int_0^{t} \lambda_{01}(u) \partial u \right) \cdot \exp\left( -\gamma_i \cdot \exp(\beta_2^Tx_i) \int_0^{t} \lambda_{02}(u) \partial u \right) \\ \end{align}\]

The form of the baseline hazard is assumed to be known when fitting the Illness-Death model. In any case, \(\int_0^{t} \lambda_{01}(u) \partial u\) is the cumulative probability for the baseline hazard.

For a Weibull baseline hazard, \(\int_0^{t} \lambda_{01}(u) \partial u = P(W \leq t)\) where \(W \sim \text{Weibull}(\alpha, \kappa)\). For piecewise constant hazard, \(\int_0^{t} \lambda_{01}(u) \partial u = \sum_{k_j \leq t} (k_j - k_{j-1}) \delta_j\), where \(k_j\) denotes the time knots and \(\delta_j\) the hazard level.

We can estimate either the true hazard function on a patient specific level including the shared multiplicative frailty () or the true marginal hazard function ().

We consider classification over time and according to

We consider the classification accuracy at a discrete set of (ideally clinically motivated) landmark time points, \(t = \{50, 75, 100\}\).

SCR_HA_list <- list()
j <- 1

# Preallocate rather than growing the vector with c() on every iteration
C_COX <- numeric(length(tseq_acc))
for (t in tseq_acc) {
  
  # At landmark time t, classify each validation subject's observed state from
  # the non-terminal (time1, event1) and terminal (time2, event2) outcomes,
  # and attach the model's marginal state-occupancy probabilities.
  # NOTE(review): row t + 1 of each ValPred matrix is assumed to hold the
  # probabilities at time t (row 1 = time 0) -- confirm against how ValPred
  # was generated.
   SCR_HA_list[[j]] <- valDat %>% 
    mutate(current_class = case_when(time1 > t & time2 > t ~ "Healthy and Alive", 
                                     time1 <= t & event1 == 1 & time2 > t ~ "Non-Terminal", 
                                     time1 <= t & event1 == 1 & 
                                       time2 <= t & event2 == 1 ~ "Both", 
                                     time2 <= t & event1 == 0 & event2 == 1 ~ "Terminal", 
                                     .default = "Censored"), 
           PP_nonterm = ValPred$p_nonterm_only_marg[t + 1, ], 
           PP_term = ValPred$p_term_only_marg[t + 1, ], 
           PP_both = ValPred$p_both_marg[t + 1, ], 
           PP_ha = ValPred$p_neither_marg[t + 1, ]
           ) %>% 
     # Standardize each state's probability across subjects so the class pick
     # can also be made on a common scale (pred_class_stand below)
     mutate(PP_nonterm_stand = scale(PP_nonterm), 
            PP_term_stand = scale(PP_term), 
            PP_both_stand = scale(PP_both), 
            PP_ha_stand = scale(PP_ha)) %>% 
    group_by(id) %>% 
    mutate(PP_max = max(PP_nonterm, PP_term, PP_both, PP_ha),
           PP_max_stand = max(PP_nonterm_stand, PP_term_stand, PP_both_stand, PP_ha_stand)) %>% 
    # Predicted class = the state whose (standardized) probability attains
    # the per-subject maximum
    mutate(pred_class = case_when(PP_nonterm == PP_max ~ "Non-Terminal", 
                                  PP_term == PP_max ~ "Terminal", 
                                  PP_both == PP_max ~ "Both", 
                                  PP_ha == PP_max ~ "Healthy and Alive", 
                                  .default = NA), 
           pred_class_stand = case_when(PP_nonterm_stand == PP_max_stand ~ "Non-Terminal", 
                                  PP_term_stand == PP_max_stand ~ "Terminal", 
                                  PP_both_stand == PP_max_stand ~ "Both", 
                                  PP_ha_stand == PP_max_stand ~ "Healthy and Alive", 
                                  .default = NA)
           ) %>% 
     ungroup() %>% 
     dplyr::select(id, current_class, pred_class, pred_class_stand, PP_ha, PP_ha_stand)
   
   # C-index for composite endpoint survival analysis.
   # NOTE(review): the filter uses time_comp while the model uses compT/compS --
   # confirm these are distinct columns, and that restricting the TRAINING data
   # to time_comp <= t (rather than administratively censoring at t) is
   # the intended landmarking scheme.
   trainDat_temp <- trainDat %>% filter(time_comp <= t)
   CoxPH_comp_model <- coxph(Surv(compT, compS) ~ x1 + x2 + x3,
                             trainDat_temp)
   # survConcordance() is deprecated; concordance(..., reverse = TRUE) matches
   # its convention that a larger predictor implies shorter survival
   C_COX[j] <- concordance(Surv(compT, compS) ~ predict(CoxPH_comp_model), 
                           data = trainDat_temp, 
                           reverse = TRUE)$concordance
   
   j <- j + 1
   
}


# Stack the per-landmark classifications into one data frame. Each list
# element contributes one row per validation subject, so each landmark time
# repeats nrow(valDat) times (replaces the hard-coded 1000, which silently
# breaks if the validation set size changes).
SCR_HA_pred <- list.rbind(SCR_HA_list)  %>% 
  mutate(time = rep(tseq_acc, each = nrow(valDat)), 
         # binary indicators: predicted (raw / standardized pick) and observed
         # "Healthy and Alive" status at each landmark time
         pred_HA = ifelse(pred_class == "Healthy and Alive", 
                          1, 0), 
         pred_HA_stand = ifelse(pred_class_stand == "Healthy and Alive", 
                          1, 0), 
         true_HA = ifelse(current_class == "Healthy and Alive", 
                          1, 0)) 

C-index

The C-index is the proportion of concordant pairs of observed and predicted outcomes among all comparable pairs. For a binary outcome it is equivalent to the area under the ROC curve (AUC): the probability that a randomly chosen positive case receives a higher predicted probability than a randomly chosen negative case.

Calibration Plot

For each state, we plot the observed indicator and predicted cumulative probability at a specific time point. There is no clear separation between the patients who are healthy and alive and the non-healthy and/or dead patients.

Brier score

The Brier score is a time dependent metric, frequently used to measure prediction accuracy. \[\begin{align} \text{Brier} & = \frac{1}{n} \sum \left(P(T_D > t, T_{NT} > t | x_i) - \delta(t) \right)^2 \end{align}\] where the indicator \(\delta(t) = (1 - \delta_D(t)) \cdot (1 - \delta_{NT}(t))\) takes the value 1 if the patient is still healthy and alive.

It is estimated on the validation data.

Specificity and Sensitivity over time

At landmark points, sensitivity (the true positive rate) and specificity (the true negative rate) are computed. They are obtained from Inverse Probability of Censoring Weighting (IPCW) estimates of the cumulative/dynamic time-dependent ROC curve.

Clinical communication

For a certain strata; how many people would be considered to be in which state at the end of some period? Shading of people over time??

# Expected state occupancy over time among a cohort of N_viz hypothetical
# patients, derived from the model's average marginal state probabilities.
N_viz <- 100
N_viz_patients <- list()
# Baseline (time 0): no events yet, all N_viz patients healthy and alive
N_viz_patients[[1]] <- data.frame(n_nonterm_diff = 0, 
                                  n_term_diff = 0, 
                                  n_both_diff = 0,
                                  
                                  n_nonterm = 0, 
                                  n_term = 0, 
                                  n_both = 0,
                                  
                                  n_ha = N_viz, 
                                  time = 0)

j <- 1 
for (t in tseq_acc) {
  
  # Average each state's marginal probability over the validation cohort and
  # scale to N_viz patients. (The previous version also built a current_class
  # column here, but summarize() discarded it, so that dead computation has
  # been removed.)
  # NOTE(review): row t + 1 of each ValPred matrix is assumed to hold the
  # probabilities at time t (row 1 = time 0).
  N_viz_temp <- valDat %>% 
    mutate(PP_nonterm = ValPred$p_nonterm_only_marg[t + 1, ], 
           PP_term = ValPred$p_term_only_marg[t + 1, ], 
           PP_both = ValPred$p_both_marg[t + 1, ]
           ) %>% 
    summarize(n_nonterm = mean(PP_nonterm) * N_viz, 
              n_term = mean(PP_term) * N_viz, 
              n_both = mean(PP_both) * N_viz)
    
  # Change in the number of patients in each state since the previous landmark.
  # BUG FIX: subtract the previous COUNT (n_*), not the previous difference
  # (n_*_diff); the two only coincide at the first iteration, so the old code
  # produced wrong diffs from the second landmark onward.
  N_viz_patients[[j + 1]] <- data.frame(
    n_nonterm_diff = N_viz_temp$n_nonterm - N_viz_patients[[j]]$n_nonterm, 
    n_term_diff = N_viz_temp$n_term - N_viz_patients[[j]]$n_term, 
    n_both_diff = N_viz_temp$n_both - N_viz_patients[[j]]$n_both, 
    
    n_nonterm = N_viz_temp$n_nonterm, 
    n_term = N_viz_temp$n_term, 
    n_both = N_viz_temp$n_both
  ) %>% 
    # Healthy-and-alive count is the remainder of the cohort
    mutate(n_ha = N_viz - N_viz_temp$n_nonterm - N_viz_temp$n_term - N_viz_temp$n_both, 
           time = t)
  j <- j + 1 
}

# Reshape to long format for plotting: one row per (time, state) with the
# expected patient count. gather() is superseded; pivot_longer() is the
# current tidyr idiom (row order differs from gather()'s column-stacking, but
# downstream use is purely name-based so the plot is unaffected).
N_viz_patients_df <- N_viz_patients %>% 
  list.rbind() %>% 
  # drop the per-interval change columns; only the running counts are plotted
  dplyr::select(-c(n_nonterm_diff, n_term_diff, n_both_diff)) %>% 
  pivot_longer(cols = -time, 
               names_to = "outcome", 
               values_to = "num_pat") %>% 
  # ordered factor controls the stacking order and legend labels
  mutate(outcome_f = factor(outcome, 
                            levels = c("n_term",
                                       "n_both", 
                                       "n_nonterm", 
                                       "n_ha"), 
                            labels = c("Dead", 
                                       "Ill + Dead", 
                                       "Ill + Alive", 
                                       "Healthy + Alive")))

# for now, draw barplot to show the change in patients 
ggplot(data = N_viz_patients_df, 
       aes(x = time,
           fill = outcome_f, 
           y = num_pat)) +
  geom_bar(stat = "identity") +
  theme_bw() +
  theme(legend.position = "bottom") +
  labs(x = "Time", 
       y = "No. of patients", 
       title = "Among 100 patients, over time, how many are suffering from non-terminal event, dead, dead post disease, or healthy and alive?", 
       fill = "") + 
  scale_fill_manual(values = c("Ill + Dead" = cb_purple,
                               "Ill + Alive" = cb_blue, 
                               "Dead" = cb_red, 
                               "Healthy + Alive" = cb_grey))